# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nvidia.dali.pipeline import pipeline_def, experimental
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.types import SampleInfo
from nvidia.dali import _conditionals
from nvidia.dali.data_node import DataNode

import numpy as np
import os

from test_utils import check_batch, compare_pipelines
from nose_utils import assert_raises
from test_utils import get_dali_extra_path
from nose2.tools import params

import itertools


def test_condition_stack():
    """Exercise `_conditionals._ConditionStack` on a hand-built, two-level nested `if`.

    The test drives the stack manually (register nodes, push predicates, track the true
    branch, preprocess inputs) and checks the scope level at which every node is found
    and how many nodes each scope reports as produced.
    """
    test_stack = _conditionals._ConditionStack()
    pred_node = DataNode("PredOp")
    pred_nested = DataNode("PredOp2")
    some_op = DataNode("SomeOp")
    some_nested_op = DataNode("SomeOp2")

    # model:
    # if pred_node:
    #     some_op()
    #     if pred_nested:
    #         some_nested_op()

    test_stack.register_data_nodes(pred_node)
    test_stack.register_data_nodes(pred_nested)
    # Both visible in global scope
    assert test_stack._find_closest(pred_node) == 0
    assert test_stack._find_closest(pred_nested) == 0
    # First predicate, no splitting required, as this is the first nesting level
    first_level = test_stack.push_predicate(pred_node)
    assert _conditionals._data_node_repr(pred_node) == _conditionals._data_node_repr(first_level)

    test_stack.track_true_branch()
    test_stack.register_data_nodes(some_op)
    assert test_stack._find_closest(some_op) == 1

    assert test_stack._find_closest(pred_nested) == 0
    assert test_stack.stack_depth() == 2

    true_split = test_stack._realize_split(pred_nested, 0)
    second_level = test_stack.push_predicate(pred_nested)
    # Second predicate requires splitting
    assert _conditionals._data_node_repr(true_split) == _conditionals._data_node_repr(second_level)
    test_stack.track_true_branch()
    test_stack.register_data_nodes(some_nested_op)
    assert test_stack._find_closest(some_nested_op) == 2

    # It's already on this level
    assert len(test_stack.top().produced) == 1
    preprocessed = test_stack.preprocess_input(some_nested_op)
    assert _conditionals._data_node_repr(some_nested_op) == _conditionals._data_node_repr(
        preprocessed
    )
    assert len(test_stack.top().produced) == 1

    # This one is not
    assert len(test_stack.top().produced) == 1
    preprocessed = test_stack.preprocess_input(some_op)
    # Compare against the preprocessed result: the node comes from an outer scope, so a new
    # node is expected here (the `produced` count below grows accordingly). Comparing
    # `some_op` with `some_nested_op` would be trivially true and verify nothing.
    assert _conditionals._data_node_repr(some_op) != _conditionals._data_node_repr(preprocessed)
    assert len(test_stack.top().produced) == 2

    test_stack.pop()
    test_stack.pop()
    assert len(test_stack.top().produced) == 2


# Unseeded generator: the random entries below differ between runs.
rng = np.random.default_rng()

# Numeric sample generators. Each callable receives a sample-info object (see how
# `generic_execute` builds `SampleInfo`) and returns an int32 value that the tests
# compare against constants inside `if` conditions.
num_gens = [
    lambda x: np.int32(x.idx_in_batch - 3),
    lambda x: np.int32(-1 if x.idx_in_batch % 2 == 0 else 1),
    lambda x: np.int32((x.idx_in_batch % 3 == 0) - 1),
    lambda _: np.int32(1),
    lambda _: np.int32(0),
    lambda _: np.int32(-1),
    lambda _: rng.choice([np.int32(-2), np.int32(0), np.int32(2)]),
]

# Boolean predicate generators. The last entry is random; `pred_gens[:-1]` is used
# where two independently fed pipelines need the same predicate.
pred_gens = [
    lambda x: np.array(x.idx_in_batch < 3),
    lambda x: np.array(x.idx_in_batch % 2 == 0),
    lambda x: np.array(x.idx_in_batch % 3 == 0),
    lambda x: np.array((x.idx_in_batch + (x.iteration % 2)) % 2 == 0),
    lambda _: np.array(False),
    lambda _: rng.choice([np.array(True), np.array(False)]),
]

# Data input generators: a constant zero, and the index of the sample within the epoch.
input_gens = [lambda x: np.array(0), lambda x: np.array(x.idx_in_epoch)]


def generic_execute(function, input_gen_list, optional_params=None):
    """Given a Python `function` (taking some positional arguments) and a list of sample generators,
    execute the function twice on batches of data generated by the generator and compare to test
    the conditional execution.

    The function is executed both as a:
    * DALI Pipeline with conditional execution enabled. External source nodes are passed
      as positional parameters and fed with the generated batches.
    * Regular function, where we pass the batches sample-by-sample to build output batches.

    Parameters
    ----------
    function : callable
        function used for testing
    input_gen_list : list of sample generators
        Possibly a stateful generator
    optional_params : list of dictionaries, optional
        Optional kwargs for external source associated with given input position, by default None

    Raises
    ------
    AssertionError
        If `optional_params` does not match `input_gen_list` in length, or if the pipeline
        returns a different number of outputs than the plain Python execution.
    """
    if optional_params is None:
        optional_params = [{} for _ in input_gen_list]
    assert len(input_gen_list) == len(optional_params), (
        "Optional param should be provided for" " every external source node."
    )
    bs = 10
    iters = 5
    kwargs = {
        "batch_size": bs,
        "num_threads": 4,
        "device_id": 0,
        "prefetch_queue_depth": 1,  # so that it's easier to use external source
    }

    # Prepare external source nodes with placeholder names; they are fed by name below.
    es_inputs = [
        fn.external_source(name=f"input_{i}", **es_kwargs)
        for i, es_kwargs in enumerate(optional_params)
    ]

    pipeline_definition = pipeline_def(enable_conditionals=True)(function)

    def gen_batch(generator, bs, iteration):
        # One sample per position in the batch; the epoch index is always 0.
        return [generator(SampleInfo(bs * iteration + i, i, iteration, 0)) for i in range(bs)]

    pipe = pipeline_definition(*es_inputs, **kwargs)

    for iteration in range(iters):
        batches = [gen_batch(gen, bs, iteration) for gen in input_gen_list]
        for i, batch in enumerate(batches):
            pipe.feed_input(f"input_{i}", batch)

        outputs = pipe.run()

        baseline_outputs = []
        for inputs_i in zip(*batches):
            outputs_i = function(*inputs_i)
            # make it a tad more generic
            if not isinstance(outputs_i, tuple):
                outputs_i = (outputs_i,)
            baseline_outputs.append(outputs_i)

        # Repack list of per-sample tuples into a list of per-output batches (lists).
        baseline_outputs = [list(baseline) for baseline in zip(*baseline_outputs)]

        # `zip` below would silently drop unmatched outputs, so check the counts first.
        assert len(outputs) == len(baseline_outputs), (
            f"Pipeline returned {len(outputs)} outputs, "
            f"while the baseline produced {len(baseline_outputs)}."
        )
        for out, baseline in zip(outputs, baseline_outputs):
            check_batch(out, baseline, bs)


# Tests below are ported from dali/test/python/autograph/converters/test_control_flow.py


@params(*num_gens)
def test_basic(num_gen):
    """Check an `if/else` where each branch reassigns a different pre-initialized variable."""

    def f(n):
        a = np.int32(0)
        b = np.int32(0)
        if n > 0:
            a = -n
        else:
            b = 2 * n
        return a, b

    generic_execute(f, [num_gen])


@params(*num_gens)
def test_complex_outputs(num_gen):
    """Check an `if/else` whose branches assign to attributes of an object, not plain names."""

    class DataClass(object):
        def __init__(self, a, b):
            self.a = a
            self.b = b

    def f(n, obj):
        obj.a = np.int32(0)
        obj.b = np.int32(0)
        if n > 0:
            obj.a = -n
        else:
            obj.b = 2 * n
        return obj.a, obj.b

    # A fresh DataClass instance is created on every call of the tested callable.
    generic_execute(lambda input: f(input, DataClass(np.int32(0), np.int32(0))), [num_gen])


@params(*num_gens)
def test_single_output(num_gen):
    """Check an `if` without `else` that conditionally overwrites the input with its negation."""

    def f(n):
        if n > 0:
            n = -n
        return n

    generic_execute(f, [num_gen])


@params(*num_gens)
def test_unbalanced(num_gen):
    """Check an `if` without `else` that conditionally overwrites the input with a constant."""

    def f(n):
        if n > 0:
            n = np.int32(3)
        return n

    generic_execute(f, [num_gen])


@params(*num_gens)
def test_local_var(num_gen):
    """Check an `if` whose body defines a branch-local variable used to compute the output."""

    def f(n):
        if n > 0:
            b = np.int32(4)
            n = b + 1
        return n

    generic_execute(f, [num_gen])


@params(*num_gens)
def test_local_remains_local(num_gen):
    """Check that a variable defined only inside the `if` body does not affect the result.

    NOTE(review): the tested function is identical to the one in `test_local_var`.
    """

    def f(n):
        if n > 0:
            b = np.int32(4)
            n = b + 1
        return n

    generic_execute(f, [num_gen])


@params(*num_gens)
def test_no_outputs(num_gen):
    """Check an `if` whose body only assigns an unused local, so the input is returned as is."""

    def f(n):
        if n > 0:
            b = np.int32(4)  # pylint:disable=unused-variable # noqa: F841
        return n

    generic_execute(f, [num_gen])


@params(*num_gens)
def test_created_outputs(num_gen):
    """Check an `if/else` where the output variable is first created inside both branches."""

    def f(i):
        if i == 0:
            result = i - 1
        else:
            result = i + 1
        return result

    generic_execute(f, [num_gen])


# Simple cases, where we produce new data node in the branch


@params(*num_gens)
def test_one_branch_new_node(num_gen):
    """Check an `if` without `else` that overwrites a pre-computed result with a new value."""

    def f(n):
        result = n * 0
        if n >= 0:
            result = n + 10
        return result

    generic_execute(f, [num_gen])


@params(*num_gens)
def test_both_branches_new_node(num_gen):
    """Check an `if/else` where each branch computes the result with a different expression."""

    def f(n):
        if n >= 0:
            result = n + 10
        else:
            result = n - 10
        return result

    generic_execute(f, [num_gen])


@params(*num_gens)
def test_chain_branches_new_node(num_gen):
    """Check an `if/elif/else` chain where every branch computes the result differently."""

    def f(n):
        if n == 0:
            result = n + 10
        elif n > 0:
            result = n + 100
        else:
            result = n - 50
        return result

    generic_execute(f, [num_gen])


# Cases where the branches only assign existing values and produce no new node, so usage
# must be detected in some way other than by looking at operator inputs


@params(*pred_gens)
def test_one_branch_only_assign(pred):
    """Check an `if` without `else` whose body only rebinds the result to another input."""

    def f(pred, base, true_branch):
        result = base
        if pred:
            result = true_branch
        return result

    # Inputs: the predicate and two constant int32 generators (42 and 7).
    generic_execute(f, [pred, lambda _: np.int32(42), lambda _: np.int32(7)])


@params(*pred_gens)
def test_both_branches_only_assign(pred):
    """Check an `if/else` where each branch only binds the result to one of the inputs."""

    def f(pred, true_branch, false_branch):
        if pred:
            result = true_branch
        else:
            result = false_branch
        return result

    # Inputs: the predicate and two constant int32 generators (6 and 9).
    generic_execute(f, [pred, lambda _: np.int32(6), lambda _: np.int32(9)])


@params(*itertools.product(pred_gens, pred_gens))
def test_chain_branches_only_assign(pred_1, pred_2):
    """Check an `if/elif/else` chain where every branch only binds the result to an input.

    Parametrized over all pairs of predicate generators.
    """

    def f(pred_1, pred_2, true_branch, elif_branch, else_branch):
        if pred_1:
            result = true_branch
        elif pred_2:
            result = elif_branch
        else:
            result = else_branch
        return result

    generic_execute(
        f, [pred_1, pred_2, lambda _: np.int32(42), lambda _: np.int32(6), lambda _: np.int32(9)]
    )


# More ifs - nesting and sequences


@params(*itertools.product(["cpu", "gpu"], input_gens, pred_gens, pred_gens))
def test_consecutive(dev, input, pred_0, pred_1):
    """Check two consecutive `if/else` statements, the second consuming the first one's output.

    `dev` is passed as the `device` argument of the external source feeding `input`.
    """

    def f(input, pred_0, pred_1):
        if pred_0:
            output = input + 1
        else:
            output = input + 2

        if pred_1:
            output2 = output + 3
        else:
            output2 = output + 4
        return output, output2

    generic_execute(f, [input, pred_0, pred_1], [{"device": dev}, {}, {}])


@params(*itertools.product(["cpu", "gpu"], input_gens, pred_gens, pred_gens))
def test_nested(dev, input, pred_0, pred_1):
    """Check an `if/else` nested in the true branch of another `if/else`.

    `dev` is passed as the `device` argument of the external source feeding `input`.
    """

    def f(input, pred_0, pred_1):
        if pred_0:
            if pred_1:
                output = input + 10
            else:
                output = input + 200
        else:
            output = input + 3000
        return output

    generic_execute(f, [input, pred_0, pred_1], [{"device": dev}, {}, {}])


@params(*itertools.product(["cpu", "gpu"], input_gens, pred_gens, pred_gens))
def test_nested_with_assignment(dev, input, pred_0, pred_1):
    """Check nested `if/else` where one inner branch only assigns a value computed outside.

    `dev` is passed as the `device` argument of the external source feeding `input`.
    """

    def f(input, pred_0, pred_1):
        to_assign = input * -5
        if pred_0:
            if pred_1:
                output = input + 10
            else:
                output = to_assign
        else:
            output = input + 3000
        return output

    generic_execute(f, [input, pred_0, pred_1], [{"device": dev}, {}, {}])


@params(*itertools.product(["cpu", "gpu"], input_gens, num_gens))
def test_multiple_nests(dev, input, num):
    """Check six levels of nested `if/else` statements, all conditioned on the same number.

    `dev` is passed as the `device` argument of the external source feeding `input`.
    """

    def f(input, num):
        if num == -2:
            if num == -1:
                if num == 0:
                    if num == 1:
                        if num == 2:
                            if num > 3:
                                output = input - 100
                            else:
                                output = input + 100
                        else:
                            output = input - 200
                    else:
                        output = input + 400
                else:
                    output = input - 800
            else:
                output = input + 1600
        else:
            output = input - 3200
        return output

    generic_execute(f, [input, num], [{"device": dev}, {}])


# Compare pure Split/Merge operators with if statement
def _impl_against_split_merge(base_additional_kwargs=None, conditional_additional_kwargs=None):
    """Compare an explicit split/merge pipeline with the equivalent `if/else` pipeline.

    Both pipelines read from the same LMDB dataset with identical seeds, rotate the samples
    for which the coin flip is true and flip the remaining ones horizontally.

    Parameters
    ----------
    base_additional_kwargs : dict, optional
        Extra keyword arguments for the pipeline using explicit split/merge operators.
    conditional_additional_kwargs : dict, optional
        Extra keyword arguments for the pipeline using the `if` statement.
    """
    # Use fresh dicts instead of shared mutable default arguments.
    if base_additional_kwargs is None:
        base_additional_kwargs = {}
    if conditional_additional_kwargs is None:
        conditional_additional_kwargs = {}

    test_data_root = get_dali_extra_path()
    caffe_db_folder = os.path.join(test_data_root, "db", "lmdb")

    bs = 10
    iters = 5
    kwargs = {"batch_size": bs, "num_threads": 4, "device_id": 0, "seed": 42}

    @experimental.pipeline_def(**kwargs, **base_additional_kwargs)
    def regular_pipe():
        encoded, _ = fn.readers.caffe(path=caffe_db_folder, seed=7)
        decoded = fn.decoders.image(encoded, device="mixed")
        pred = fn.random.coin_flip(dtype=types.DALIDataType.BOOL, seed=8)
        true, false = fn._conditional.split(decoded, predicate=pred)
        output_true = fn.rotate(true, angle=30)
        output_false = fn.flip(false, horizontal=True)
        return fn._conditional.merge(output_true, output_false, predicate=pred)

    @experimental.pipeline_def(enable_conditionals=True, **kwargs, **conditional_additional_kwargs)
    def conditional_pipe():
        encoded, _ = fn.readers.caffe(path=caffe_db_folder, seed=7)
        decoded = fn.decoders.image(encoded, device="mixed")
        pred = fn.random.coin_flip(dtype=types.DALIDataType.BOOL, seed=8)
        if pred:
            output = fn.rotate(decoded, angle=30)
        else:
            output = fn.flip(decoded, horizontal=True)
        return output

    pipes = [regular_pipe(), conditional_pipe()]
    compare_pipelines(*pipes, bs, iters)


def test_against_split_merge():
    """Run the split/merge versus `if/else` comparison with default pipeline arguments."""
    _impl_against_split_merge()


# Compare pure Split/Merge operators with if statement to see if DataNodes produced by `.gpu()`
# are registered
def _impl_dot_gpu(base_additional_kwargs=None, conditional_additional_kwargs=None):
    """Compare split/merge with `if/else` when the branches call `.gpu()` on data nodes.

    Besides comparing the outputs, the conditional pipeline asserts on the names of the
    nodes returned by `.gpu()`; those asserts are skipped when
    `conditional_additional_kwargs` is non-empty.

    Parameters
    ----------
    base_additional_kwargs : dict, optional
        Extra keyword arguments for the pipeline using explicit split/merge operators.
    conditional_additional_kwargs : dict, optional
        Extra keyword arguments for the pipeline using the `if` statement.
    """
    # Use fresh dicts instead of shared mutable default arguments.
    if base_additional_kwargs is None:
        base_additional_kwargs = {}
    if conditional_additional_kwargs is None:
        conditional_additional_kwargs = {}

    test_data_root = get_dali_extra_path()
    caffe_db_folder = os.path.join(test_data_root, "db", "lmdb")

    bs = 10
    iters = 5
    kwargs = {"batch_size": bs, "num_threads": 4, "device_id": 0, "seed": 42}

    @experimental.pipeline_def(**kwargs, **base_additional_kwargs)
    def regular_pipe():
        encoded, _ = fn.readers.caffe(path=caffe_db_folder, seed=1)
        decoded = fn.decoders.image(encoded, device="cpu")
        pred = fn.random.coin_flip(dtype=types.DALIDataType.BOOL, seed=2)
        true, false = fn._conditional.split(decoded, predicate=pred)
        output_true = fn.rotate(true.gpu(), angle=30)
        output_false = fn.flip(false, horizontal=True).gpu()
        return fn._conditional.merge(output_true, output_false, predicate=pred)

    @experimental.pipeline_def(enable_conditionals=True, **kwargs, **conditional_additional_kwargs)
    def conditional_pipe():
        encoded, _ = fn.readers.caffe(path=caffe_db_folder, seed=1)
        decoded = fn.decoders.image(encoded)
        pred = fn.random.coin_flip(dtype=types.DALIDataType.BOOL, seed=2)
        if pred:
            decoded_gpu_true = decoded.gpu()
            # The `decoded` will be split as we look it up in a scope of a branch,
            # so the new node is built based on that split batch
            if not conditional_additional_kwargs:
                assert "__Split" in decoded_gpu_true.name
            output = fn.rotate(decoded_gpu_true, angle=30)
        else:
            output = fn.flip(decoded, name="flip_in_else", horizontal=True)
            output = output.gpu()
            # here we create a new node based on the one already produced in this scope,
            # so the source name is kept
            if not conditional_additional_kwargs:
                assert output.name == "flip_in_else"
        return output

    pipes = [regular_pipe(), conditional_pipe()]
    compare_pipelines(*pipes, bs, iters)


def test_dot_gpu():
    """Run the `.gpu()` registration comparison with default pipeline arguments."""
    _impl_dot_gpu()


# Test if operators without positional inputs but with argument inputs are correctly handled
# in the split/merge - so they are tracked in the local scope.


def _impl_arg_inputs_scoped_tracking(global_additional_kwargs=None, scoped_additional_kwargs=None):
    """Compare building an argument-input-only operator outside and inside an `if` branch.

    Both pipelines are conditional and identically seeded; they differ only in whether
    `fn.transforms.rotation` is invoked in the global scope or within the true branch.

    Parameters
    ----------
    global_additional_kwargs : dict, optional
        Extra keyword arguments for the pipeline creating the transform in the global scope.
    scoped_additional_kwargs : dict, optional
        Extra keyword arguments for the pipeline creating the transform inside the branch.
    """
    # Use fresh dicts instead of shared mutable default arguments.
    if global_additional_kwargs is None:
        global_additional_kwargs = {}
    if scoped_additional_kwargs is None:
        scoped_additional_kwargs = {}

    test_data_root = get_dali_extra_path()
    caffe_db_folder = os.path.join(test_data_root, "db", "lmdb")

    bs = 10
    iters = 5
    kwargs = {"batch_size": bs, "num_threads": 4, "device_id": 0, "seed": 42}

    @experimental.pipeline_def(enable_conditionals=True, **kwargs, **global_additional_kwargs)
    def global_transform_pipe():
        encoded, _ = fn.readers.caffe(path=caffe_db_folder)
        decoded = fn.decoders.image(encoded, device="mixed")
        pred = fn.random.coin_flip(dtype=types.DALIDataType.BOOL, seed=6)
        angle = fn.random.uniform(values=[10, 20, 30], seed=7)
        rotate_transform = fn.transforms.rotation(angle=angle)
        if pred:
            output = fn.warp_affine(decoded, matrix=rotate_transform)
        else:
            output = decoded
        return output

    @experimental.pipeline_def(enable_conditionals=True, **kwargs, **scoped_additional_kwargs)
    def scoped_transform_pipe():
        encoded, _ = fn.readers.caffe(path=caffe_db_folder)
        decoded = fn.decoders.image(encoded, device="mixed")
        pred = fn.random.coin_flip(dtype=types.DALIDataType.BOOL, seed=6)
        angle = fn.random.uniform(values=[10, 20, 30], seed=7)
        if pred:
            # This is the crux of the test, the transforms.rotate has no positional inputs,
            # but it has a DataNode argument input - it should detect it as produced in this scope.
            rotate_transform = fn.transforms.rotation(angle=angle)
            output = fn.warp_affine(decoded, matrix=rotate_transform)
        else:
            output = decoded
        return output

    pipes = [global_transform_pipe(), scoped_transform_pipe()]
    compare_pipelines(*pipes, bs, iters)


def test_arg_inputs_scoped_tracking():
    """Run the argument-input scope tracking comparison with default pipeline arguments."""
    _impl_arg_inputs_scoped_tracking()


def _impl_arg_inputs_scoped_uninitialized(additional_kwargs=None):
    """Check that a transform created only in the true branch cannot be returned.

    The pipeline returns `rotate_transform`, which is assigned only inside the `if` branch,
    and the test expects a `RuntimeError` about the variable missing in the `else` branch.

    Parameters
    ----------
    additional_kwargs : dict, optional
        Extra keyword arguments for the pipeline definition.
    """
    # Use a fresh dict instead of a shared mutable default argument.
    if additional_kwargs is None:
        additional_kwargs = {}

    test_data_root = get_dali_extra_path()
    caffe_db_folder = os.path.join(test_data_root, "db", "lmdb")
    bs = 10
    kwargs = {"batch_size": bs, "num_threads": 4, "device_id": 0}

    @experimental.pipeline_def(enable_conditionals=True, **kwargs, **additional_kwargs)
    def scoped_transform_pipe():
        encoded, _ = fn.readers.caffe(path=caffe_db_folder)
        decoded = fn.decoders.image(encoded, device="mixed")
        pred = fn.random.coin_flip(dtype=types.DALIDataType.BOOL)
        angle = fn.random.uniform(values=[10, 20, 30])
        if pred:
            rotate_transform = fn.transforms.rotation(angle=angle)
            output = fn.warp_affine(decoded, matrix=rotate_transform)
        else:
            output = decoded
        # Check that the rotate_transform is indeed local to the branch by trying to return it
        # and generating uninitialized error.
        return output, rotate_transform

    with assert_raises(
        RuntimeError,
        glob=(
            "Encountered inconsistent outputs out of the `if/else` control flow"
            " statement. Variables need to be initialized in every code path"
            " (both `if` branches). Variable 'rotate_transform' must also be"
            " initialized in the `else` branch."
        ),
    ):
        pipe = scoped_transform_pipe()
        pipe.run()


def test_arg_inputs_scoped_uninitialized():
    """Run the branch-local transform error check with default pipeline arguments."""
    _impl_arg_inputs_scoped_uninitialized()


# Unified return tests - TODO(klecki)

# Generator tests, remove the random predicate to test the same predicate in both pipelines.


# NOTE(review): this decorator looks redundant on a helper that `test_generators` (which
# carries the same parametrization) calls directly - confirm before removing it.
@params(*(pred_gens[:-1]))
def _impl_generators(pred, base_additional_kwargs=None, conditional_additional_kwargs=None):
    """Compare generator operators invoked inside `if/else` with an explicit split/merge.

    The baseline runs the reader and the random generator unconditionally and keeps their
    outputs only for the samples where the predicate is true; the remaining samples get
    constant zeros. The conditional pipeline invokes the same operators inside the branches.

    Parameters
    ----------
    pred : callable
        Per-sample predicate source, passed to `fn.external_source` with `batch=False`.
    base_additional_kwargs : dict, optional
        Extra keyword arguments for the baseline pipeline. When non-empty, the baseline
        wraps its constants in `types.Constant` explicitly.
    conditional_additional_kwargs : dict, optional
        Extra keyword arguments for the pipeline using the `if` statement.
    """
    # Use fresh dicts instead of shared mutable default arguments.
    if base_additional_kwargs is None:
        base_additional_kwargs = {}
    if conditional_additional_kwargs is None:
        conditional_additional_kwargs = {}

    test_data_root = get_dali_extra_path()
    caffe_db_folder = os.path.join(test_data_root, "db", "lmdb")

    bs = 10
    iters = 5
    kwargs = {"batch_size": bs, "num_threads": 4, "device_id": 0, "seed": 42}

    @experimental.pipeline_def(**kwargs, **base_additional_kwargs)
    def baseline_pipe():
        encoded, _ = fn.readers.caffe(path=caffe_db_folder, seed=10)
        rand = fn.random.uniform(seed=11)
        predicate = fn.external_source(source=pred, batch=False)
        true_encoded, _ = fn._conditional.split(encoded, predicate=predicate)
        true_rand, _ = fn._conditional.split(rand, predicate=predicate)
        # TODO(klecki): Debug mode currently requires explicit constants instantiation
        if base_additional_kwargs:
            u8_zeros = types.Constant(np.uint8([0]), device="cpu")
            f32_zeros = types.Constant(np.float32(0.0), device="cpu")
        else:
            u8_zeros = np.uint8([0])
            f32_zeros = np.float32(0.0)
        _, false_u8 = fn._conditional.split(u8_zeros, predicate=predicate)
        _, false_f32 = fn._conditional.split(f32_zeros, predicate=predicate)
        encoded_out = fn._conditional.merge(true_encoded, false_u8, predicate=predicate)
        rand_out = fn._conditional.merge(true_rand, false_f32, predicate=predicate)
        return encoded_out, rand_out

    @experimental.pipeline_def(enable_conditionals=True, **kwargs, **conditional_additional_kwargs)
    def conditional_pipe():
        predicate = fn.external_source(source=pred, batch=False)
        # Generators work by running in top scope and splitting for particular nesting
        if predicate:
            encoded_out, _ = fn.readers.caffe(path=caffe_db_folder, seed=10)
            rand_out = fn.random.uniform(seed=11)
        else:
            encoded_out = types.Constant(np.uint8([0]), device="cpu")
            rand_out = types.Constant(np.float32(0.0), device="cpu")
        return encoded_out, rand_out

    pipes = [baseline_pipe(), conditional_pipe()]
    compare_pipelines(*pipes, bs, iters)


@params(*(pred_gens[:-1]))
def test_generators(pred):
    """Run the generator comparison for every predicate generator except the last (random) one."""
    _impl_generators(pred)


# Mismatched branches test (uninitialized values)


def _impl_uninitialized(additional_kwargs=None):
    """Check the errors raised when only one branch initializes the pipeline output.

    Two cases are covered: a variable assigned only in the `if` branch and returned
    afterwards, and a `return` statement placed only inside the `if` branch.

    Parameters
    ----------
    additional_kwargs : dict, optional
        Extra keyword arguments for the `one_branch` pipeline definition.
    """
    # Use a fresh dict instead of a shared mutable default argument.
    if additional_kwargs is None:
        additional_kwargs = {}

    bs = 10
    kwargs = {
        "batch_size": bs,
        "num_threads": 4,
        "device_id": 0,
    }

    @experimental.pipeline_def(enable_conditionals=True, **kwargs, **additional_kwargs)
    def one_branch():
        pred = fn.random.coin_flip(dtype=types.DALIDataType.BOOL)
        if pred:
            output = fn.random.uniform()
        return output

    with assert_raises(
        RuntimeError,
        glob=(
            "Encountered inconsistent outputs out of the `if/else` control flow"
            " statement. Variables need to be initialized in every code path"
            " (both `if` branches). Variable 'output' must also be initialized"
            " in the `else` branch."
        ),
    ):
        p = one_branch()
        p.run()

    # NOTE(review): `additional_kwargs` is not forwarded to this pipeline, unlike
    # `one_branch` above - confirm whether that is intentional.
    @experimental.pipeline_def(enable_conditionals=True, **kwargs)
    def one_return():
        pred = fn.random.coin_flip(dtype=types.DALIDataType.BOOL)
        if pred:
            return fn.random.uniform()

    with assert_raises(
        RuntimeError,
        glob=(
            "Encountered inconsistent outputs out of the `if/else` control flow"
            " statement. Variables need to be initialized in every code path"
            " (both `if` branches). The `else` branch must also have a return"
            " statement."
        ),
    ):
        p = one_return()
        p.run()


def test_uninitialized():
    """Run the uninitialized-output error checks with default pipeline arguments."""
    _impl_uninitialized()


def _tensor_arg_permute_batch_params():
    """Build the inputs, masks and `indices` argument batches for `fn.permute_batch`.

    Returns a tuple `(args_batches, mask_batches, kwargs_batches)` for batch sizes 1, 5
    and 8: one positional input made of 2x2 float32 tensors filled with the sample index,
    boolean masks that are true for odd sample indices, and the same masks as int32.
    """
    inputs, masks, indices = [], [], []
    for size in (1, 5, 8):
        inputs.append([np.full((2, 2), idx, dtype=np.float32) for idx in range(size)])
        odd = (np.arange(size) % 2).astype(bool)
        masks.append(odd)
        indices.append(odd.astype(np.int32))
    return (inputs,), masks, {"indices": indices}


def _tensor_arg_transform_per_dim_params(arg_name):
    """Return a factory of parameters for operators taking a 2-element argument per sample.

    The returned callable produces `(args_batches, mask_batches, kwargs_batches)` for batch
    sizes 5, 1, 2 and 8, with no positional inputs, boolean masks that are true for odd
    sample indices and, under `arg_name`, float32 arrays holding the mask value twice per
    sample.
    """

    def inner():
        masks, per_dim = [], []
        for size in (5, 1, 2, 8):
            odd = (np.arange(size) % 2).astype(bool)
            masks.append(odd)
            per_dim.append(np.stack((odd, odd), axis=1).astype(np.float32))
        return (), masks, {arg_name: per_dim}

    return inner


def _tensor_arg_rotate_params():
    """Build the masks and `angle` argument batches for `fn.transforms.rotation`.

    Returns `(args_batches, mask_batches, kwargs_batches)` for batch sizes 3, 1, 2 and 4,
    with no positional inputs, boolean masks that are true for odd sample indices and
    float32 angles equal to 10 where the mask is false and 55 where it is true.
    """
    masks, angles = [], []
    for size in (3, 1, 2, 4):
        odd = (np.arange(size) % 2).astype(bool)
        masks.append(odd)
        angles.append((10 + 45 * odd).astype(np.float32))
    return (), masks, {"angle": angles}


def _tensor_arg_roi_random_crop_params():
    """Build the masks and argument batches for `fn.roi_random_crop`.

    Returns `(args_batches, mask_batches, kwargs_batches)` for batch sizes 1, 2, 7 and 3,
    with no positional inputs and boolean masks that are true for odd sample indices.
    Each `crop_shape` sample is the int32 triple `[100 * i + 50, 200 * i + 50, 3]`;
    `roi_start` halves the first two entries of it and `roi_end` reuses the `crop_shape`
    batches.
    """
    divisors = np.array([2, 2, 1], dtype=np.int32)
    crop_shape, roi_start, mask_batches = [], [], []
    for size in (1, 2, 7, 3):
        shapes = [
            np.array([100 * idx + 50, 200 * idx + 50, 3], dtype=np.int32) for idx in range(size)
        ]
        crop_shape.append(shapes)
        roi_start.append([extent // divisors for extent in shapes])
        mask_batches.append((np.arange(size) % 2).astype(bool))
    kwargs_batches = {"crop_shape": crop_shape, "roi_start": roi_start, "roi_end": crop_shape}
    return (), mask_batches, kwargs_batches


def _tensor_arg_shape_kwarg():
    """Build the masks and `shape` argument batches for the random generator operators.

    Returns `(args_batches, mask_batches, kwargs_batches)` for batch sizes 1, 2, 3, 16
    and 5, with no positional inputs, boolean masks that are true for odd sample indices
    and int32 shapes `[1 + 3 * i, 2 * i + 1]` for the i-th sample.
    """
    shape, mask_batches = [], []
    for size in (1, 2, 3, 16, 5):
        shape.append([np.array([3 * idx + 1, 2 * idx + 1], dtype=np.int32) for idx in range(size)])
        mask_batches.append((np.arange(size) % 2).astype(bool))
    return (), mask_batches, {"shape": shape}


# Test operators that infer their batch sizes from the tensor argument inputs
@params(
    fn.permute_batch,
    fn.roi_random_crop,
    fn.transforms.crop,
    fn.transforms.scale,
    fn.transforms.shear,
    fn.transforms.translation,
    fn.transforms.rotation,
    fn.random.uniform,
    fn.random.normal,
    fn.random.coin_flip,
)
def test_named_tensor_arguments(op):
    """Run `op` on both halves of a split batch, passing split tensors as named arguments.

    All positional and named inputs are fed through external sources, split with the same
    mask, passed to two separate instances of `op` and merged back. The pipeline is run
    once per prepared mask batch.
    """
    per_dim = _tensor_arg_transform_per_dim_params
    param_makers = {
        fn.permute_batch: _tensor_arg_permute_batch_params,
        fn.roi_random_crop: _tensor_arg_roi_random_crop_params,
        fn.transforms.crop: per_dim("from_start"),
        fn.transforms.scale: per_dim("scale"),
        fn.transforms.shear: per_dim("angles"),
        fn.transforms.translation: per_dim("offset"),
        fn.transforms.rotation: _tensor_arg_rotate_params,
        fn.random.uniform: _tensor_arg_shape_kwarg,
        fn.random.normal: _tensor_arg_shape_kwarg,
        fn.random.coin_flip: _tensor_arg_shape_kwarg,
    }

    def replay(batches):
        # Wrap the prepared batches in a generator function usable as an external source.
        def produce():
            yield from batches

        return produce

    args_batches, mask_batches, kwargs_batches = param_makers[op]()
    largest_batch = max(len(mask) for mask in mask_batches)

    @pipeline_def(batch_size=largest_batch, num_threads=4, device_id=0)
    def split_pipeline():
        positional = [fn.external_source(replay(batches)) for batches in args_batches]
        mask = fn.external_source(replay(mask_batches))
        named = {
            name: fn.external_source(replay(batches)) for name, batches in kwargs_batches.items()
        }
        # Every split yields a pair of halves; index 0 is routed to `left`, 1 to `right`.
        named_halves = {
            name: fn._conditional.split(node, predicate=mask) for name, node in named.items()
        }
        positional_halves = [fn._conditional.split(node, predicate=mask) for node in positional]

        left_args = [halves[0] for halves in positional_halves]
        right_args = [halves[1] for halves in positional_halves]
        left_kwargs = {name: halves[0] for name, halves in named_halves.items()}
        left = op(*left_args, **left_kwargs)
        right_kwargs = {name: halves[1] for name, halves in named_halves.items()}
        right = op(*right_args, **right_kwargs)
        return fn._conditional.merge(left, right, predicate=mask)

    pipe = split_pipeline()
    for _ in mask_batches:
        pipe.run()


@params((32, 0), (32, 1), (32, 7), (32, 16), (32, 31), (32, 32))
def test_simple_batch_permute(batch_size, permute_prefix):
    """
    Permute `permute_prefix` of the batch and leave the remaining part untouched
    """

    @pipeline_def(batch_size=batch_size, device_id=0, num_threads=4, enable_conditionals=True)
    def pipeline():
        # Every sample carries its own index within the batch.
        sample_idx = fn.external_source(
            lambda sample_info: np.array(sample_info.idx_in_batch, dtype=np.int32), batch=False
        )
        # Only samples whose index falls within the prefix get replaced.
        if sample_idx < permute_prefix:
            sample_idx = fn.batch_permutation()
        return sample_idx

    p = pipeline()

    for _ in range(3):
        (sample_indices,) = p.run()
        sample_indices = [np.array(sample).item() for sample in sample_indices]
        permuted_prefix = sample_indices[:permute_prefix]
        expected_prefix = list(range(permute_prefix))
        untouched_suffix = sample_indices[permute_prefix:]
        expected_suffix = list(range(permute_prefix, batch_size))
        # The prefix may come in any order, but must contain exactly the prefix indices.
        assert sorted(permuted_prefix) == expected_prefix, (
            f"expected permuted prefix `{permuted_prefix}` to contain the "
            f"following samples {expected_prefix}"
        )
        assert untouched_suffix == expected_suffix, (
            f"expected untouched suffix `{untouched_suffix}` to be exactly " f"{expected_suffix}"
        )


# fn.batch_permutation is a special operator in the context of conditional
# execution: it relies on the local batch size to generate a valid permutation
# for the split batch, yet it does not accept any explicit arguments from which
# the local batch size could be inferred.
@params((7, 1), (1, 1), (16, 2), (16, 3), (101, 3))
def test_batch_permutation(batch_size, num_split_level):
    """
    Split the batch into `2**num_split_level` random groups and permute
    the groups separately
    """

    def split_and_permute(batch, num_levels, group=0):
        # Recursively nest `num_levels` coin-flip conditions; `group` accumulates a Python
        # integer identifying the path of branches taken.
        assert num_levels >= 0
        if num_levels == 0:
            return fn.permute_batch(batch, indices=fn.batch_permutation()), group
        else:
            if fn.random.coin_flip():
                return split_and_permute(batch, num_levels - 1, group)
            else:
                return split_and_permute(batch, num_levels - 1, group + 2 ** (num_levels - 1))

    @pipeline_def(batch_size=batch_size, device_id=0, num_threads=4, enable_conditionals=True)
    def pipeline():
        sample_idx = fn.external_source(
            lambda sample_info: np.array(sample_info.idx_in_batch), batch=False
        )
        sample_idx, group = split_and_permute(sample_idx, num_split_level)
        return sample_idx, group

    p = pipeline()

    for _ in range(3):
        sample_indices, groups = p.run()
        sample_idx = [np.array(sample).item() for sample in sample_indices]
        group = [np.array(sample).item() for sample in groups]
        # From here on `groups` maps a group id to a pair of lists: (got, expected).
        groups = {i: ([], []) for i in range(2**num_split_level)}

        # For every output position record the value found there (got) and the position
        # itself (expected) under the group reported for that position.
        for group_idx in range(batch_size):
            got, expected = groups[group[group_idx]]
            got.append(sample_idx[group_idx])
            expected.append(group_idx)

        # Within each group the values must be a rearrangement of the group's positions.
        for group_idx, (got, expected) in groups.items():
            assert sorted(got) == expected, f"{group_idx}: {got} vs {expected}"


def test_error_condition():
    """
    Check the errors reported for `if` predicates that are not accepted:
    a predicate placed on the GPU and a predicate that is not a scalar.
    """
    kwargs = {
        "enable_conditionals": True,
        "batch_size": 10,
        "num_threads": 4,
        "device_id": 0,
    }

    @experimental.pipeline_def(**kwargs)
    def gpu_condition():
        pred = fn.random.coin_flip(dtype=types.DALIDataType.BOOL)
        # We have to create a new, pure GPU node, otherwise we still find the CPU node.
        if pred.gpu() | False:
            output = np.array(1)
        else:
            output = np.array(0)
        return output

    # TODO(klecki): Extend the error checking so we can provide better error message here.
    with assert_raises(
        ValueError,
        glob=(
            "Invalid device \"gpu\" for argument 'predicate' of operator "
            "'nvidia.dali.fn._conditional.split'."
        ),
    ):
        # Both the construction and the run stay inside the context, so the error is
        # caught regardless of which of the two steps reports it.
        pipe = gpu_condition()
        pipe.run()

    @experimental.pipeline_def(**kwargs)
    def non_scalar_condition():
        pred = fn.random.coin_flip(dtype=types.DALIDataType.BOOL)
        # Stacking two scalars gives a 1-d predicate.
        stacked = fn.stack(pred, pred)
        if stacked:
            output = np.array(1)
        else:
            output = np.array(0)
        return output

    with assert_raises(
        RuntimeError,
        glob=(
            "Conditions inside `if` statements are restricted to scalar"
            " (0-d tensors) inputs, that are placed on CPU."
            " Got a 1-d input as a condition of the `if` statement."
        ),
    ):
        pipe = non_scalar_condition()
        pipe.run()


# Element types used to parametrize test_predicate_any_type: every one of them is
# expected to be usable as the predicate of an `if` statement.
boolable_types = [
    bool,
    *(np.uint8, np.uint16, np.uint32, np.uint64),  # unsigned integers
    *(np.int8, np.int16, np.int32, np.int64),  # signed integers
    *(np.float16, np.float32, np.float64),  # floating point
]


@params(*boolable_types)
def test_predicate_any_type(input_type):
    """
    Check that a scalar of `input_type` can be used as the `if` predicate.

    The first half of the batch is fed the value 7 and the second half the value 0;
    the pipeline is expected to return 42 for the former and 0 for the latter.
    """
    batch_size = 10
    half_batch = batch_size / 2

    def get_truthy_falsy(sample_info):
        # Non-zero for the first half of the batch, zero for the rest.
        value = 7 if sample_info.idx_in_batch < half_batch else 0
        return np.array(value, dtype=input_type)

    @experimental.pipeline_def(
        enable_conditionals=True, batch_size=batch_size, num_threads=4, device_id=0
    )
    def non_bool_predicate():
        predicate = fn.external_source(source=get_truthy_falsy, batch=False)
        if predicate:
            output = types.Constant(np.array(42), device="cpu")
        else:
            output = types.Constant(np.array(0), device="cpu")
        return output

    pipe = non_bool_predicate()
    (batch,) = pipe.run()

    expected = [42 if idx < half_batch else 0 for idx in range(batch_size)]
    check_batch(batch, expected)


def test_data_node_if_error():
    """Check that `if` on a DataNode raises TypeError when conditionals are disabled."""

    @pipeline_def(enable_conditionals=False, batch_size=10, num_threads=4, device_id=0)
    def pipeline():
        predicate = fn.random.coin_flip()
        if predicate:
            output = types.Constant(np.array(42), device="cpu")
        else:
            output = types.Constant(np.array(0), device="cpu")
        return output

    # The message is expected to point the user at the `enable_conditionals` switch.
    expected_message = (
        '"DataNode" was used in conditional context*'
        " To use conditional execution via `if` statements you need to specify"
        " `enable_conditionals=True` in `@nvidia.dali.pipeline_def` decorator*"
    )
    with assert_raises(TypeError, glob=expected_message):
        pipe = pipeline()
        pipe.run()


def test_sanity_enable_conditionals():
    """
    Smoke test: `enable_conditionals` is passed when the pipeline is instantiated,
    next to the pipeline's own arguments, with `@pipeline_def` used without parentheses.
    """

    # Use no parenthesis version:
    @pipeline_def
    def pipeline(a, b):
        predicate = fn.random.coin_flip()
        if predicate:
            output = types.Constant(np.array(a), device="cpu")
        else:
            output = types.Constant(np.array(b), device="cpu")
        return output

    # `a` is passed positionally and `b` by keyword, mixed with the pipeline options.
    pipe = pipeline(
        10,
        b=4,
        enable_conditionals=True,
        batch_size=10,
        num_threads=4,
        device_id=0,
    )
    pipe.run()


def test_multiple_input_source():
    """
    Check `fn.copy` applied inside an `if` branch to a list of two inputs: a constant
    expression and the per-sample index coming from an external source.

    Samples from the first half of the batch take the true branch and should yield
    `[42]` and their own index; the remaining ones should yield `[0]` and `0`.
    """
    batch_size = 16
    half_batch = batch_size / 2

    @pipeline_def(batch_size=batch_size, device_id=0, num_threads=4, enable_conditionals=True)
    def pipeline():
        sample_idx = fn.external_source(
            lambda sample_info: np.array(sample_info.idx_in_batch, dtype=np.int32), batch=False
        )

        const_42 = types.Constant(np.uint8([42]), device="cpu")
        if sample_idx < half_batch:
            out_42_scoped, out_idx_scoped = fn.copy(
                [const_42 + types.Constant(0, dtype=types.UINT8), sample_idx]
            )
        else:
            out_42_scoped = types.Constant(np.uint8([0]), device="cpu")
            out_idx_scoped = types.Constant(np.int32(0), device="cpu")

        return out_42_scoped, out_idx_scoped

    # The expected outputs are the same for every iteration.
    in_first_half = [idx < half_batch for idx in range(batch_size)]
    expected_42 = [[42] if first else [0] for first in in_first_half]
    expected_idx = [idx if first else 0 for idx, first in enumerate(in_first_half)]

    pipe = pipeline()
    for _ in range(4):
        out_42, out_idx = pipe.run()
        check_batch(out_42, expected_42)
        check_batch(out_idx, expected_idx)


def test_exception_explanation():
    """
    Check that a ValueError raised by a helper called from a conditional pipeline
    definition is reported together with the names of the user functions involved.
    """

    # NOTE: the names `throwing_helper` and `throwing_pipeline` are matched by the glob below.
    def throwing_helper():
        raise ValueError("I am throwing")

    @pipeline_def(batch_size=10, device_id=0, num_threads=4, enable_conditionals=True)
    def throwing_pipeline():
        if fn.random.coin_flip():
            result = np.full((1, 1), 10)
        else:
            result = throwing_helper()
        return result

    # Check that the error message contains the user code from before the autograph translation
    # as an explanation.
    expected_message = "*in user code*in throwing_pipeline*in throwing_helper*ValueError"
    with assert_raises(ValueError, glob=expected_message):
        throwing_pipeline()
